{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from omegaconf import DictConfig\n",
    "from transformers import MistralForCausalLM\n",
    "\n",
    "from fusion_bench.method.mixture_of_experts.mixtral_upcycling import (\n",
    "    MixtralForCausalLMUpscalingAlgorithm,\n",
    ")\n",
    "from fusion_bench.utils import print_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9cc1b690077041ca8e5c3f8eff3bc4b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pretrained model:\n",
      "trainable params: 7.24B || all params: 7.24B || trainable%: 100.0000\n"
     ]
    }
   ],
   "source": [
    "# Expand a leading \"~\" (if any) in the checkpoint location before loading.\n",
    "pretrained_model_path = os.path.expanduser(\"path_to_mistral_model\")\n",
    "\n",
    "# Load the pre-trained Mistral model from that location.\n",
    "pretrained_model = MistralForCausalLM.from_pretrained(pretrained_model_path)\n",
    "\n",
    "# Print the parameter statistics of the loaded model.\n",
    "print(\"Pretrained model:\")\n",
    "print_parameters(pretrained_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1405e7358b3b4225a017f23524324bdf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Initializing Mixtral model:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0195592c70974e419e6c349782c7e0dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upscaling layers:   0%|          | 0/32 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mixtral model:\n",
      "trainable params: 24.15B || all params: 24.15B || trainable%: 100.0000\n"
     ]
    }
   ],
   "source": [
    "# Mixture-of-experts settings handed to the upscaling algorithm.\n",
    "config = {\n",
    "    \"num_experts\": 4,  # Number of experts\n",
    "    \"experts_per_token\": 2,  # Experts to choose per token\n",
    "}\n",
    "\n",
    "# Initialize the upscaling algorithm; the settings are wrapped in an OmegaConf DictConfig.\n",
    "upscaling_for_causal_lm_algorithm = MixtralForCausalLMUpscalingAlgorithm(\n",
    "    DictConfig(config)\n",
    ")\n",
    "\n",
    "# Run the upscaling process on the pre-trained Mistral model to get a Mixtral model.\n",
    "mixtral_for_causal_lm_model = upscaling_for_causal_lm_algorithm.run(pretrained_model)\n",
    "\n",
    "# Print the parameter statistics of the upscaled model.\n",
    "print(\"Mixtral model:\")\n",
    "print_parameters(mixtral_for_causal_lm_model)\n",
    "\n",
    "# Save the upscaled Mixtral model. The destination is passed through expanduser, the\n",
    "# same way the pre-trained model's path is, so that a \"~/...\" destination resolves to\n",
    "# the home directory instead of creating a literal \"~\" directory.\n",
    "mixtral_for_causal_lm_model.save_pretrained(\n",
    "    os.path.expanduser(\"path_to_save_mixtral_model\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
